# import packages
import contextlib
import math
import sys

import numpy as np
import pandas as pd
import statsmodels.api as sm

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Estimate the disparity by an OLS regression of ``outcome`` on ``pbvarb``.

    Fits ``outcome ~ pbvarb`` (intercept included by the formula interface)
    with HC1 heteroskedasticity-robust standard errors, and prints the number
    of observations used in the fit.

    Args:
        dataset: DataFrame holding the ``pbvarb`` and ``outcome`` columns.
        pbvarb: Name of the regressor column.
        outcome: Name of the dependent-variable column.

    Returns:
        Tuple ``(coef, se, black_aud_rate, non_black_aud_rate)`` where
        ``coef`` and ``se`` are the slope on ``pbvarb`` and its robust
        standard error, ``non_black_aud_rate`` is the intercept (fitted value
        at ``pbvarb == 0``) and ``black_aud_rate`` is intercept + slope
        (fitted value at ``pbvarb == 1``).
    """
    model = sm.OLS.from_formula(f"{outcome} ~ {pbvarb}", data=dataset).fit(cov_type='HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    # Look parameters up by label: integer-position access on a
    # string-labelled Series (params[0], params[1]) is deprecated in pandas.
    intercept = model.params['Intercept']
    black_aud_rate = intercept + coef
    non_black_aud_rate = intercept
    print("number of obs :" + str(model.nobs))
    return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Estimate the disparity as a difference of probability-weighted means.

    The first group's rate is the mean of ``outcome`` weighted by ``pbvarb``;
    the second group's rate is the mean of ``outcome`` weighted by
    ``1 - pbvarb``.

    Args:
        dataset: DataFrame holding the ``pbvarb`` and ``outcome`` columns.
        pbvarb: Name of the column of weights (values expected in [0, 1]).
        outcome: Name of the outcome column.

    Returns:
        Tuple ``(est, black_aud_rate, non_black_aud_rate)`` where ``est`` is
        ``black_aud_rate - non_black_aud_rate``.
    """
    weight = dataset[pbvarb]
    result = dataset[outcome]
    # Use only rows where both columns are observed. Otherwise a row with a
    # missing outcome is skipped in the numerator's sum but its weight still
    # inflates the denominator, biasing the weighted mean toward zero.
    observed = weight.notna() & result.notna()
    weight = weight[observed]
    result = result[observed]

    complement = 1 - weight
    black_aud_rate = (weight * result).sum() / weight.sum()
    non_black_aud_rate = (complement * result).sum() / complement.sum()
    est = black_aud_rate - non_black_aud_rate
    return est, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
    """Pair a regression standard error with its rescaled counterpart.

    Args:
        dataset: DataFrame holding the ``pbvarb`` column.
        pbvarb: Name of the column used to compute the rescaling factor.
        outcome: Accepted for signature compatibility; not used here.
        seReg: Standard error from the linear estimator, supplied by the
            caller. With the default of 0.0 both returned values are 0.

    Returns:
        Tuple ``(seReg, seChen)`` where ``seChen`` is ``seReg`` multiplied by
        ``getSEMultiplier(dataset, pbvarb)``.
    """
    scale = getSEMultiplier(dataset, pbvarb)
    return seReg, seReg * scale

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
    """Return sqrt(var(p) / (mean(p) * (1 - mean(p)))) for column ``pbvarb``.

    ``var`` is the pandas default (sample variance, ``ddof=1``).

    Args:
        dataset: DataFrame holding the ``pbvarb`` column.
        pbvarb: Name of the column to summarise.

    Returns:
        The scalar multiplier described above.
    """
    prob = dataset[pbvarb]
    avg = prob.mean()
    denominator = avg * (1 - avg)
    ratio = prob.var() / denominator
    return np.sqrt(ratio)


### OP AUDIT
# Load the individual-level file and keep only rows with a non-missing
# predicted_prob_black. The explicit .copy() makes the filtered frame an
# independent object, so adding a column below does not trigger pandas'
# SettingWithCopyWarning.
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final_new_eic.csv")
dataBISG = dataBISG[dataBISG.predicted_prob_black.notna()].copy()

# define non-dep_database audit: flagged in aud_no_research_audits but not
# in dep_database_aud
is_non_dep_database = (dataBISG['aud_no_research_audits'] == 1) & (dataBISG['dep_database_aud'] == 0)
dataBISG['non_dep_database_aud'] = np.where(is_non_dep_database, 1, 0)

# subset to EITC claimants
dataBISG_EITC = dataBISG[dataBISG.isEIC == 1]

##########################################################
#### Disparity Table for dep_database audits
##########################################################
# Write out results. Everything printed inside the ``with`` block (including
# the "number of obs" line printed by regEstimate) goes to the results file.
# The context managers guarantee the file is closed and the previous stdout
# is restored even if one of the computations raises.
with open("/REDACTED/disparity_results_dep_database_eitc_update.txt", "w") as results_file, \
        contextlib.redirect_stdout(results_file):

    # Calculate dep_database audit disparity
    linDisparity = regEstimate(dataBISG_EITC, pbvarb = 'predicted_prob_black', outcome = 'dep_database_aud')
    probDisparity = chenEstimate(dataBISG_EITC, pbvarb = 'predicted_prob_black', outcome = 'dep_database_aud')
    linSE, probSE = getSEs(dataBISG_EITC, pbvarb = 'predicted_prob_black', outcome = 'dep_database_aud', seReg = float(linDisparity[1]))

    dep_database_overall_lin = linDisparity[0]
    dep_database_overall_prob = probDisparity[0]
    dep_database_aud_rate = dataBISG_EITC.dep_database_aud.mean()

    # Calculate non-dep_database audit disparity
    linDisparity_non = regEstimate(dataBISG_EITC, pbvarb = 'predicted_prob_black', outcome = 'non_dep_database_aud')
    probDisparity_non = chenEstimate(dataBISG_EITC, pbvarb = 'predicted_prob_black', outcome = 'non_dep_database_aud')
    linSE_non, probSE_non = getSEs(dataBISG_EITC, pbvarb = 'predicted_prob_black', outcome = 'non_dep_database_aud', seReg = float(linDisparity_non[1]))

    non_dep_database_overall_lin = linDisparity_non[0]
    non_dep_database_overall_prob = probDisparity_non[0]
    non_dep_database_aud_rate = dataBISG_EITC.non_dep_database_aud.mean()

    print("dep_database Audits")
    print("Overall Linear Disparity: "+str(dep_database_overall_lin))
    print("Linear Standard Error: "+str(linSE))
    print("Overall Probabilistic Disparity: "+str(dep_database_overall_prob))
    print("Probabilistic Standard Error: "+str(probSE)+"\n")
    print("Audit Rate: "+str(dep_database_aud_rate)+"\n")

    print("Non-dep_database Audits")
    print("Overall Linear Disparity: "+str(non_dep_database_overall_lin))
    print("Linear Standard Error: "+str(linSE_non))
    print("Overall Probabilistic Disparity: "+str(non_dep_database_overall_prob))
    print("Probabilistic Standard Error: "+str(probSE_non)+"\n")
    print("Audit Rate: "+str(non_dep_database_aud_rate)+"\n")

    # Share = dep_database disparity / (dep_database + non-dep_database disparity)
    print("Share of Disparity Attributable to dep_database")
    print("Linear: " +str(dep_database_overall_lin/(dep_database_overall_lin + non_dep_database_overall_lin)))
    print("Probabilistic: " +str(dep_database_overall_prob/(dep_database_overall_prob+non_dep_database_overall_prob)))

    # Calculate share of EITC audits selected through the dep_database program
    dataBISG_EITC_aud = dataBISG_EITC[dataBISG_EITC.aud_no_research_audits == 1]
    eitc_share_dep_database = dataBISG_EITC_aud.dep_database_aud.mean()
    print("\n(For caption) Share of EITC audits selected through dep_database: " + str(eitc_share_dep_database))
